from gym_minigrid.minigrid import *
from envs.minigrid.minigrid_extensions import *
from rm_wrappers import RewardMachineEnv, RewardMachineNoisyBeliefUpdateEnv
from random import randint

class KitchenEnv(MiniGridEnv):
    """
    An environment where both thresholding and belief update will fail.
    """
    def __init__(
        self,
        size=9,
        agent_start_pos=None,
        agent_start_dir=0,
        timeout=100,
        is_locked=True,
        randomize_chores=False,
    ):
        """Store the kitchen options and initialise the MiniGrid base class.

        Args:
            size: Grid side length, forwarded to the base class as
                ``grid_size``; the step limit is ``4 * size * size``.
            agent_start_pos: Fixed starting cell for the agent. When ``None``,
                ``_gen_grid`` calls ``place_agent`` instead.
            agent_start_dir: Starting direction; only applied together with
                ``agent_start_pos``.
            timeout: Accepted for interface compatibility but not read
                anywhere in this class.
            is_locked: Whether the doors are created locked.
            randomize_chores: When true, the three chore tiles are drawn at
                random cells instead of the fixed default cells.
        """
        # All options are stored before the base constructor is invoked
        # (presumably it triggers reset()/_gen_grid(), which read them --
        # TODO confirm against the installed gym_minigrid version).
        self.randomize_chores = randomize_chores
        self.is_locked = is_locked
        self.target_pos = (1, 1)
        self.event_objs = []
        self.agent_start_dir = agent_start_dir
        self.agent_start_pos = agent_start_pos

        step_limit = 4 * size * size
        super().__init__(
            grid_size=size,
            max_steps=step_limit,
            # Set this to True for maximum speed
            see_through_walls=False,
            agent_view_size=7,
        )

    def reset(self):
        """Clear the recorded event string, then reset the base environment.

        Returns:
            Whatever the base class ``reset`` returns.
        """
        # Emptied before delegating: _gen_grid appends the letters of chores
        # that start out completed (presumably it runs inside the base reset
        # -- TODO confirm).
        self.events = ""
        initial_obs = super().reset()
        return initial_obs

    def _gen_grid(self, width, height):
        """Build the kitchen: walls, doors, target, chore tiles and agent.

        A vertical wall at column 4 holds three yellow doors. The target
        (a ``Goal``) sits at ``self.target_pos``. Three chore tiles are placed
        either at fixed cells or, with ``randomize_chores``, at three distinct
        random cells with x in [5, 7] and y in [1, 7]. Each chore then starts
        out already completed (grey, event letter recorded) with probability
        1/3.

        Args:
            width: Grid width in cells.
            height: Grid height in cells.
        """
        # Empty grid enclosed by an outer wall
        self.grid = Grid(width, height)
        self.grid.wall_rect(0, 0, width, height)

        # Inner wall with three doors
        self.grid.vert_wall(4, 0)
        self.door_poss = [(4, 2), (4, 4), (4, 6)]
        self.num_doors = 3
        # Approach cells: first the cells directly left of each door, then
        # the cells directly right of each door.
        left_cells = [(x - 1, y) for (x, y) in self.door_poss]
        right_cells = [(x + 1, y) for (x, y) in self.door_poss]
        self.before_door_poss = left_cells + right_cells
        self.doors = [Door('yellow', is_locked=self.is_locked) for _ in range(self.num_doors)]
        for door, cell in zip(self.doors, self.door_poss):
            self.put_obj(door, *cell)

        # Target square
        self.put_obj(Goal(), *self.target_pos)

        # Choose the chore cells.
        # NOTE(review): `randint` is random.randint, not the environment's own
        # RNG, so seeding the env alone may not reproduce this layout --
        # confirm whether that is intended.
        if self.randomize_chores:
            x_min, x_max, y_min, y_max = 5, 7, 1, 7
            picked = []
            while len(picked) < 3:
                cell = (randint(x_min, x_max), randint(y_min, y_max))
                if cell not in picked:
                    picked.append(cell)
            self.chore_poss = picked
        else:
            self.chore_poss = [(5, 7), (7, 7), (7, 1)]
        self.chores = [Floor('green'), Floor('blue'), Floor('red')]

        # Event letters: 'a', 'b', 'c' for the chores, 'd' for the target
        self.event_objs = list(zip(self.chore_poss, 'abc'))
        self.event_objs.append((self.target_pos, 'd'))

        # Randomly complete each chore with 1/3 probability
        for tile, (_, label) in zip(self.chores, self.event_objs):
            if randint(0, 2) == 0:
                tile.color = 'grey'
                if label not in self.events:
                    self.events += label

        for cell, tile in zip(self.chore_poss, self.chores):
            self.put_obj(tile, *cell)

        # Place the agent
        if self.agent_start_pos is None:
            self.place_agent(top=(1, 1), size=(3, 7))
        else:
            self.agent_pos = self.agent_start_pos
            self.agent_dir = self.agent_start_dir

        self.mission = "do two of three chores and exit the room to get to the target"

    def step(self, action):
        """Advance one step and return the kitchen-specific (negative) reward.

        Before delegating to the base class: a forward action taken while
        facing direction 0 from a cell directly left of a closed door opens
        and unlocks every door, at a cost of 0.05.

        After the base step: if the action was forward and the agent stands
        on a chore tile, that tile turns grey, its event letter is recorded
        in ``self.events`` (once), and 0.05 is charged. The charge applies
        every time, including for chores that were already completed.

        Args:
            action: The action to execute (compared against
                ``self.actions.forward``).

        Returns:
            Tuple ``(obs, reward, done, info)``. ``obs``, ``done`` and
            ``info`` come from the base class; ``reward`` is the negated sum
            of the penalties above, and the base class reward is discarded.
        """
        penalty = 0
        # Automatically unlock the doors.
        # Direction 0 faces the door from the left-hand approach cells
        # (assumes MiniGrid's convention that 0 points towards +x).
        if action == self.actions.forward and self.agent_dir == 0:
            here = tuple(self.agent_pos)
            # zip stops after the last door, so only the first num_doors
            # approach cells (the left-hand ones) are considered, each paired
            # with its own door.
            for approach_cell, door in zip(self.before_door_poss, self.doors):
                # Test the door actually being approached rather than always
                # door 0; otherwise closing door 0 by hand makes every pass
                # through another, already open door pay the penalty again.
                if here == approach_cell and not door.is_open:
                    penalty += 0.05
                    for any_door in self.doors:
                        any_door.is_open = True
                        any_door.is_locked = False
                    break

        next_state, _, done, info = super().step(action)

        for i, pos in enumerate(self.chore_poss):
            if tuple(self.agent_pos) == pos and action == self.actions.forward:
                penalty += 0.05
                self.chores[i].color = 'grey'
                label = self.event_objs[i][1]
                if label not in self.events:
                    self.events += label
        return next_state, -penalty, done, info

    def get_events(self):
        for pos, event in self.event_objs:
            if tuple(self.agent_pos) == pos and event not in self.events:
                self.events += event
        return self.events

    def get_sync_rm_func(self):
        """
        Synchronize the underlying RM state with the randomly initialized kitchen domain. Return nothing.
        """
        # def func(rm_env:RewardMachineEnv):
        #     rm_env.current_u_id = self.start_rm_u_id
        #     return
        return None

    def get_sync_rm_belief_func(self):
        """
        Synchronize the agent's prior over the initial RM states
        """
        # def func(rm_env:RewardMachineNoisyBeliefUpdateEnv):
        #     belief_u_dist = [0.] * rm_env.num_rm_states
        #     start_u_ids = (1, 2, 3)
        #     start_probs = 1./3.
        #     for u_id in start_u_ids:
        #         belief_u_dist[u_id] = start_probs
        #     rm_env.belief_u_dist = belief_u_dist
        #     return
        return None